
Implement EfficientNet #113

Closed
wants to merge 5 commits into from

Conversation

@pxl-th (Member) commented Feb 7, 2022

This is an implementation of the EfficientNet model, written in a similar way to the PyTorch EfficientNet.

TODO

  • Documentation

@pxl-th (Member, Author) commented Feb 7, 2022

As for the pretrained weights, we can use utilities to load PyTorch's weights into the model, save them as .bson files, and ship them as artifacts.
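
A rough sketch of what that pipeline could look like (the .npz export, key ordering, and file names below are assumptions for illustration, not actual utilities from this PR):

using BSON, NPZ, Flux

model = EfficientNet(:b0)                           # Flux model from this PR
torch_params = NPZ.npzread("efficientnet_b0.npz")   # parameters exported from PyTorch beforehand

# Copy each exported array into the corresponding Flux parameter.
# A real converter needs an explicit name mapping and dimension permutation
# (PyTorch and Flux use different weight layouts); sorted keys are only a placeholder here.
for (p, key) in zip(Flux.params(model), sort(collect(keys(torch_params))))
    copyto!(p, torch_params[key])
end

BSON.@save "efficientnet_b0.bson" model             # this file would be shipped as an artifact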

@DhairyaLGandhi mentioned this pull request Feb 8, 2022

@pxl-th (Member, Author) commented Feb 11, 2022

Wanted to add pretrained weights, but it looks like Metalhead's utilities for loading weights only include the trainable parameters:

loadpretrain!(model, name) = Flux.loadparams!(model, weights(name))

But it is important to also load the moving mean and variance for the BatchNorm layers.
Otherwise the accuracy of the model decreases significantly and does not match the PyTorch version.
Maybe we should instead save the whole model directly, as the Flux documentation suggests?
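
For reference, a minimal sketch of the whole-model approach from the Flux docs (the file name is just an example); unlike Flux.loadparams!, this round-trips non-trainable state such as the BatchNorm running statistics:

using BSON: @save, @load

@save "efficientnet-b0.bson" model   # serialize the entire model, including BatchNorm mean/var

# later, e.g. after downloading the shipped artifact:
@load "efficientnet-b0.bson" model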

@pxl-th marked this pull request as ready for review February 11, 2022 22:23

@pxl-th (Member, Author) commented Feb 11, 2022

Additionally, maybe we should start adding inferrability tests for the models (in light of "Taking TTFX seriously")?
However, to make models that use BatchNorm type-stable, we'd need a new release of Flux.jl with FluxML/Flux.jl#1856.
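
A minimal sketch of what such a test could look like (the testset layout and input size are assumptions):

using Test, Flux

@testset "EfficientNet inferrability" begin
    m = EfficientNet(:b0)
    x = rand(Float32, 224, 224, 3, 1)
    y = @inferred m(x)   # throws if the return type of the forward pass cannot be inferred
    @test y isa AbstractArray{Float32}
end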

@ToucheSir (Member) commented

> Maybe we should instead save the whole model directly, as the Flux documentation suggests?

Based on FluxML/MetalheadWeights#2, that seems to be the way to go.

@darsnack (Member) commented Feb 18, 2022

Thank you for this PR!

I have filed FluxML/Flux.jl#1875 to address the issue about loading state. Along with FluxML/MetalheadWeights#2, we should be able to handle models saved with @save "name.bson" model.

The implementation in this PR is pretty much good to go. It mostly needs to be refactored to match the model building style of this package, and it needs to reuse existing functions. Based on this, I suggest the following (this is not a complete rewrite—the internal code is the same—it was just too annoying to input as a GH review comment):

function efficientnet(imsize, scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
  wscale, dscale = scalings
  out_channels = _round_channels(32, 8)
  stem = Chain(Conv((3, 3), inchannels => out_channels; bias = false, stride = 2, pad = SamePad()),
               BatchNorm(out_channels, swish))

  blocks = []
  for (n, k, s, e, i, o) in block_config
    in_channels = _round_channels(i, 8)
    out_channels = _round_channels(o, 8)
    repeat = dscale ≈ 1 ? n : ceil(Int64, dscale * n)

    push!(blocks, invertedresidual(k, in_channels, in_channels * e, out_channels, swish;
                                   stride = s, reduction = 4))
    for _ in 1:(repeat - 1)
      push!(blocks, invertedresidual(k, out_channels, out_channels * e, out_channels, swish;
                                     stride = 1, reduction = 4))
    end
  end
  blocks = Chain(blocks...)

  head_out_channels = _round_channels(max_width, 8)
  head = Chain(Conv((1, 1), out_channels => head_out_channels; bias = false, pad = SamePad()),
               BatchNorm(head_out_channels, swish))

  top = Dense(head_out_channels, nclasses)

  return Chain(Chain(stem, blocks, head),
               Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top))
end

# n: # of block repetitions
# k: kernel size k x k
# s: stride
# e: expansion ratio
# i: block input channels
# o: block output channels
const efficientnet_block_configs = [
# (n, k, s, e,   i,   o)
  (1, 3, 1, 1,  32,  16),
  (2, 3, 2, 6,  16,  24),
  (2, 5, 2, 6,  24,  40),
  (3, 3, 2, 6,  40,  80),
  (3, 5, 1, 6,  80, 112),
  (4, 5, 2, 6, 112, 192),
  (1, 3, 1, 6, 192, 320)
]

# w: width scaling
# d: depth scaling
# r: image resolution
const efficient_global_configs = Dict(
#        (  r, (  w,   d))
  :b0 => (224, (1.0, 1.0)),
  :b1 => (240, (1.0, 1.1)),
  :b2 => (260, (1.1, 1.2)),
  :b3 => (300, (1.2, 1.4)),
  :b4 => (380, (1.4, 1.8)),
  :b5 => (456, (1.6, 2.2)),
  :b6 => (528, (1.8, 2.6)),
  :b7 => (600, (2.0, 3.1)),
  :b8 => (672, (2.2, 3.6))
)

struct EfficientNet
  layers
end

function EfficientNet(imsize, scalings, block_config;
                      inchannels = 3, nclasses = 1000, max_width = 1280)
  layers = efficientnet(imsize, scalings, block_config;
                        inchannels = inchannels, nclasses = nclasses, max_width = max_width)
  EfficientNet(layers)
end

@functor EfficientNet

(m::EfficientNet)(x) = m.layers(x)

backbone(m::EfficientNet) = m.layers[1]
classifier(m::EfficientNet) = m.layers[2]

function EfficientNet(name::Symbol; pretrain = false)
  @assert name in keys(efficient_global_configs) "`name` must be one of $(sort(collect(keys(efficient_global_configs))))"

  model = EfficientNet(efficient_global_configs[name]..., efficientnet_block_configs)
  pretrain && loadpretrain!(model, string("EfficientNet", name))

  return model
end

Note that this requires a rebase to pass since it depends on #120. The above code can be under src/convnets/efficientnet.jl (we no longer need MBConv or the parameter structs).
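
For completeness, a quick usage sketch of the constructor above (not part of the original suggestion; the output shape assumes the :b0 config and a 224×224 input):

m = EfficientNet(:b0)
x = rand(Float32, 224, 224, 3, 1)
size(m(x))  # (1000, 1)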

@darsnack (Member) commented Apr 8, 2022

@pxl-th any interest in reviving this PR with the feedback above? If not, is it okay if I continue off this PR to complete it?

@pxl-th (Member, Author) commented Apr 12, 2022

Sorry, I don't have a lot of free time currently. It is totally ok if you are interested in completing this PR :)

@darsnack (Member) commented

Superseded by #171.

@darsnack closed this Jun 25, 2022